{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "edca7c01-5f19-4179-8c2a-0f58c058df5c",
   "metadata": {},
   "source": [
    "## 手写数据识别"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "db56adc8-4092-4946-8df0-187d96e10378",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(60000, 784)\n",
      "(60000,)\n",
      "(10000, 784)\n",
      "(10000,)\n"
     ]
    }
   ],
   "source": [
    "import sys, os\n",
    "sys.path.append(os.pardir)\n",
    "from dataset.mnist import load_mnist\n",
    "\n",
    "(x_train, t_train), (x_test, t_test) = load_mnist(flatten = True, normalize = False)\n",
    "\n",
    "print(x_train.shape)\n",
    "print(t_train.shape)\n",
    "print(x_test.shape)\n",
    "print(t_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a5df01cd-1d94-47f6-94d3-dfcc9984639f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5\n",
      "(784,)\n",
      "(28, 28)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAcTUlEQVR4nO3df3DU9b3v8dcCyQqaLI0hv0rAgD+wAvEWJWZAxJJLSOc4gIwHf3QGvF4cMXiKaPXGUZHWM2nxjrV6qd7TqURnxB+cEaiO5Y4GE441oQNKGW7blNBY4iEJFSe7IUgIyef+wXXrQgJ+1l3eSXg+Zr4zZPf75vvx69Znv9nNNwHnnBMAAOfYMOsFAADOTwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYGGG9gFP19vbq4MGDSktLUyAQsF4OAMCTc04dHR3Ky8vTsGH9X+cMuAAdPHhQ+fn51ssAAHxDzc3NGjt2bL/PD7gApaWlSZJm6vsaoRTj1QAAfJ1Qtz7QO9H/nvcnaQFat26dnnrqKbW2tqqwsFDPPfecpk+ffta5L7/tNkIpGhEgQAAw6Pz/O4ye7W2UpHwI4fXXX9eqVau0evVqffTRRyosLFRpaakOHTqUjMMBAAahpATo6aef1rJly3TnnXfqO9/5jl544QWNGjVKL774YjIOBwAYhBIeoOPHj2vXrl0qKSn5x0GGDVNJSYnq6upO27+rq0uRSCRmAwAMfQkP0Geffaaenh5lZ2fHPJ6dna3W1tbT9q+srFQoFIpufAIOAM4P5j+IWlFRoXA4HN2am5utlwQAOAcS/im4zMxMDR8+XG1tbTGPt7W1KScn57T9g8GggsFgopcBABjgEn4FlJqaqmnTpqm6ujr6WG9vr6qrq1VcXJzowwEABqmk/BzQqlWrtGTJEl1zzTWaPn26nnnmGXV2durOO+9MxuEAAINQUgK0ePFi/f3vf9fjjz+u1tZWXX311dq6detpH0wAAJy/As45Z72Ir4pEIgqFQpqt+dwJAQAGoROuWzXaonA4rPT09H73M/8UHADg/ESAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYGGG9AGAgCYzw/5/E8DGZSVhJYjQ8eElccz2jer1nxk885D0z6t6A90zr06neMx9d87r3jCR91tPpPVO08QHvmUtX1XvPDAVcAQEATBAgAICJhAfoiSeeUCAQiNkmTZqU6MMAAAa5pLwHdNVVV+m99977x0Hi+L46AGBoS0oZRowYoZycnGT81QCAISIp7wHt27dPeXl5mjBhgu644w4dOHCg3327uroUiURiNgDA0JfwABUVFamqqkpbt27V888/r6amJl1//fXq6Ojoc//KykqFQqHolp+fn+glAQAGoIQHqKysTLfccoumTp2q0tJSvfPOO2pvb9cbb7zR5/4VFRUKh8PRrbm5OdFLAgAMQEn/dMDo0aN1+eWXq7Gxsc/ng8GggsFgspcBABhgkv5zQEeOHNH+/fuVm5ub7EMBAAaRhAfowQcfVG1trT755BN9+OGHWrhwoYYPH67bbrst0YcCAAxiCf8W3KeffqrbbrtNhw8f1pgxYzRz5kzV19drzJgxiT4UAGAQS3iAXnvttUT/lRighl95mfeMC6Z4zxy8YbT3zBfX+d9EUpIyQv5z/1EY340uh5rfHk3znvnZ/5rnPbNjygbvmabuL7xnJOmnbf/VeybvP1xcxzofcS84AIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMBE0n8hHQa+ntnfjWvu6ap13jOXp6TGdSycW92ux3vm8eeWes+M6PS/cWfxxhXeM2n/ecJ7RpKCn/nfxHTUzh1xHet8xBUQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATHA3bCjYcDCuuV3H8r1nLk9pi+tYQ80DLdd5z/z1SKb3TNXEf/eekaRwr/9dqrOf/TCuYw1k/mcBPrgCAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMcDNS6ERLa1xzz/3sFu+Zf53X6T0zfM9F3jN/uPc575l4PfnZVO+ZxpJR3jM97S3eM7cX3+s9I0mf/Iv/TIH+ENexcP7iCggAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMMHNSBG3jPV13jNj3rrYe6bn8OfeM1dN/m/eM5L0f2e96D3zm3+7wXsmq/1D75l4BOriu0Fogf+/WsAbV0AAABMECABgwjtA27dv10033aS8vDwFAgFt3rw55nnnnB5//HHl5uZq5MiRKikp0b59+xK1XgDAEOEdoM7OThUWFmrdunV9Pr927Vo9++yzeuGFF7Rjxw5deOGFKi0t1bFjx77xYgEAQ4f3hxDKyspUVlbW53POOT3zzDN69NFHNX/+fEnSyy+/rOzsbG3evFm33nrrN1stAGDISOh7QE1NTWptbVVJSUn0sVAopKKiItXV9f2xmq6uLkUikZgNADD0JTRAra2tkqTs7OyYx7Ozs6PPnaqyslKhUCi65efnJ3JJAIAByvxTcBUVFQqHw9GtubnZekkAgHMgoQHKycmRJLW1tcU83tbWFn3uVMFgUOnp6TEbAGDoS2iACgoKlJOTo+rq6uhjkUhEO3bsUHFxcSIPBQAY5Lw/BXfkyBE1NjZGv25qatLu3buVkZGhcePGaeXKlXryySd12WWXqaCgQI899pjy8vK0YMGCRK4bADDIeQdo586duvHGG6Nfr1q1SpK0ZMkSVVVV6aGHHlJnZ6fuvvtutbe3a+bMmdq6dasuuOCCxK0aADDoBZxzznoRXxWJRBQKhTRb8zUikGK9HAxSf/nf18Y3908veM/c+bc53jN/n9nhPaPeHv8ZwMAJ160abVE4HD7j+/rmn4IDAJyfCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYML71zEAg8GVD/8lrrk7p/jf2Xr9+Oqz73SKG24p955Je73eewYYyLgCAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMcDNSDEk97eG45g4vv9J75sBvvvCe+R9Pvuw9U/HPC71n3Mch7xlJyv/XOv8h5+I6Fs5fXAEBAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACa4GSnwFb1/+JP3zK1rfuQ988rq/+k9s/s6/xuY6jr/EUm66sIV3jOX/arFe+bEXz/xnsHQwRUQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGAi4Jxz1ov4qkgkolAopNmarxGBFOvlAEnhZlztPZP+00+9Z16d8H+8Z+I16f3/7j1zxZqw90zPvr96z+DcOuG6VaMtCofDSk9P73c/roAAACYIEADAhHeAtm/frptuukl5eXkKBALavHlzzPNLly5VIBCI2ebNm5eo9QIAhgjvAHV2dqqwsFDr1q3rd5958+appaUlur366qvfaJEAgKHH+zeilpWVqays7Iz7BINB5eTkxL0oAMDQl5T3gGpqapSVlaUrrrhCy5cv1+HDh/vdt6urS5FIJGYDAAx9CQ/QvHnz9PLLL6u6ulo/+9nPVFtbq7KyMvX09PS5f2VlpUKhUHTLz89P9JIAAAOQ97fgzubWW2+N/nnKlCmaOnWqJk6cqJqaGs2ZM+e0/SsqKrRq1aro15FIhAgBwHkg6R/DnjBhgjIzM9XY2Njn88FgUOnp6TEbAGDoS3qAPv30Ux0+fFi5ubnJPhQAYBDx/hbckSNHYq5mmpqatHv3bmVkZCgjI0Nr1qzRokWLlJOTo/379+uhhx7SpZdeqtLS0oQuHAAwuHkHaOfOnbrxxhujX3/5/s2SJUv0/PPPa8+ePXrppZfU3t6uvLw8zZ07Vz/5yU8UDAYTt2oAwKDHzUiBQWJ4dpb3zMHFl8Z1rB0P/8J7Zlgc39G/o2mu90x4Zv8/1oGBgZuRAgAGNAIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJhI+K/kBpAcPW2HvGeyn/WfkaRjD53wnhkVSPWe+dUlb3vP/NPCld4zozbt8J5B8nEFBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCY4GakgIHemVd7z+y/5QLvmclXf+I9I8V3Y9F4PPf5f/GeGbVlZxJWAgtcAQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJrgZKfAVgWsme8/85V/8b9z5qxkvec/MuuC498y51OW6vWfqPy/wP1Bvi/8MBiSugAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAE9yMFAPeiILx3jP778yL61hPLH7Ne2bRRZ/FdayB7JG2a7xnan9xnffMt16q857B0MEVEADABAECAJjwClBlZaWuvfZapaWlKSsrSwsWLFBDQ0PMPseOHVN5ebkuvvhiXXTRRVq0aJHa2toSumgAwODnFaDa2lqVl5ervr5e7777rrq7uzV37lx1dnZG97n//vv11ltvaePGjaqtrdXBgwd18803J3zhAIDBzetDCFu3bo35uqqqSllZWdq1a5dmzZqlcDisX//619qwYYO+973vSZLWr1+vK6+8UvX19bruOv83KQEAQ9M3eg8oHA5LkjIyMiRJu3btUnd3t0pKSqL7TJo0SePGjVNdXd+fdunq6lIkEonZAABDX9wB6u3t1cqVKzVjxgxNnjxZktTa2qrU1FSNHj06Zt/s7Gy1trb2+fdUVlYqFApFt/z8/HiXBAAYROIOUHl5ufbu3avXXvP/uYmvqqioUDgcjm7Nzc3f6O8DAAwOcf0g6ooVK/T2229r+/btGjt2bPTxnJwcHT9+XO3t7TFXQW1tbcrJyenz7woGgwoGg/EsAwAwiHldATnntGLFCm3atEnbtm1TQUFBzPPTpk1TSkqKqquro481NDTowIEDKi4uTsyKAQBDgtcVUHl5uTZs2KAtW7YoLS0t+r5OKBTSyJEjFQqFdNddd2nVqlXKyMhQenq67rvvPhUXF/MJOABADK8APf/885Kk2bNnxzy+fv16LV26VJL085//XMOGDdOiRYvU1dWl0tJS/fKXv0zIYgEAQ0fAOeesF/FVkUhEoVBIszVfIwIp1svBGYy4ZJz3THharvfM4h9vPftOp7hn9F+9Zwa6B1r8v4tQ90v/m4pKUkbV7/2HenviOhaGnhOuWzXaonA4rPT09H73415wAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMBHXb0TFwDUit+/fPHsmn794YVzHWl5Q6z1zW1pbXMcayFb850zvmY+ev9p7JvPf93rPZHTUec8A5wpXQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACW5Geo4cL73Gf+b+z71nHrn0He+ZuSM7vWcGuraeL+Kam/WbB7xnJj36Z++ZjHb/m4T2ek8AAxtXQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACW5Geo58ssC/9X+ZsjEJK0mcde0TvWd+UTvXeybQE/CemfRkk/eMJF3WtsN7pieuIwHgCggAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMBFwzjnrRXxVJBJRKBTSbM3XiECK9XIAAJ5OuG7VaIvC4bDS09P73Y8rIACACQIEADDhFaDKykpde+21SktLU1ZWlhYsWKCGhoaYfWbPnq1AIBCz3XPPPQldNABg8PMKUG1trcrLy1VfX693331X3d3dmjt3rjo7O2P2W7ZsmVpaWqLb2rVrE7poAMDg5/UbUbdu3RrzdVVVlbKysrRr1y7NmjUr+vioUaOUk5OTmBUCAIakb/QeUDgcliRlZGTEPP7KK68oMzNTkydPVkVFhY4ePdrv39HV1aVIJBKzAQCGPq8roK/q7e3VypUrNWPGDE2ePDn6+O23367x48crLy9Pe/bs0cMPP6yGhga9+eabff49lZWVWrNmTbzLAAAMUnH/HNDy5cv129/+Vh988IHGjh3b737btm3TnDlz1NjYqIkTJ572fFdXl7q6uqJfRyIR5efn83NAADBIfd2fA4rrCmjFihV6++23tX379jPGR5KKiookqd8ABYNBBYPBeJYBABjEvALknNN9992nTZs2qaamRgUFBWed2b17tyQpNzc3rgUCAIYmrwCVl5drw4YN2rJli9LS0tTa2ipJCoVCGjlypPbv368NGzbo+9//vi6++GLt2bNH999/v2bNmqWpU6cm5R8AADA4eb0HFAgE+nx8/fr1Wrp0qZqbm/WDH/xAe/fuVWdnp/Lz87Vw4UI9+uijZ/w+4FdxLzgAGNyS8h7Q2VqVn5+v2tpan78SAHCe4l5wAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATI6wXcCrnnCTphLolZ7wYAIC3E+qW9I//nvdnwAWoo6NDkvSB3jFeCQDgm+jo6FAoFOr3+YA7W6LOsd7eXh08eFBpaWkKBAIxz0UiEeXn56u5uVnp6elGK7THeTiJ83AS5+EkzsNJA+E8OOfU0dGhvLw8DRvW/zs9A+4KaNiwYRo7duwZ90lPTz+vX2Bf4jycxHk4ifNwEufhJOvzcKYrny/xIQQAgAkCBAAwMagCFAwGtXr1agWDQeulmOI8nMR5OInzcBLn4aTBdB4G3IcQAADnh0F1BQQAGDoIEADABAECAJggQAAAE4MmQOvWrdMll1yiCy64QEVFRfr9739vvaRz7oknnlAgEIjZJk2aZL2spNu+fbtuuukm5eXlKRAIaPPmzTHPO+f0+OOPKzc3VyNHjlRJSYn27dtns9gkOtt5WLp06Wmvj3nz5tksNkkqKyt17bXXKi0tTVlZWVqwYIEaGhpi9jl27JjKy8t18cUX66KLLtKiRYvU1tZmtOLk+DrnYfbs2ae9Hu655x6jFfdtUATo9ddf16pVq7R69Wp99NFHKiwsVGlpqQ4dOmS9tHPuqquuUktLS3T74IMPrJeUdJ2dnSosLNS6dev6fH7t2rV69tln9cILL2jHjh268MILVVpaqmPHjp3jlSbX2c6DJM2bNy/m9fHqq6+ewxUmX21trcrLy1VfX693331X3d3dmjt3rjo7O6P73H///Xrrrbe0ceNG1dbW6uDBg7r55psNV514X+c8SNKyZctiXg9r1641WnE/3CAwffp0V15eHv26p6fH5eXlucrKSsNVnXurV692hYWF1sswJclt2rQp+nVvb6/LyclxTz31VPSx9vZ2FwwG3auvvmqwwnPj1PPgnHNLlixx8+fPN1mPlUOHDjlJrra21jl38t99SkqK27hxY3SfP/3pT06Sq6urs1pm0p16Hpxz7oYbbnA//OEP7Rb1NQz4K6Djx49r165dKikpiT42bNgwlZSUqK6uznBlNvbt26e8vDxNmDBBd9xxhw4cOGC9JFNNTU1qbW2NeX2EQiEVFRWdl6+PmpoaZWVl6YorrtDy5ct1+PBh6yUlVTgcliRlZGRIknbt2qXu7u6Y18OkSZM0bty4If16OPU8fOmVV15RZmamJk+erIqKCh09etRief0acDcjPdVnn32mnp4eZWdnxzyenZ2tP//5z0arslFUVKSqqipdccUVamlp0Zo1a3T99ddr7969SktLs16eidbWVknq8/Xx5XPni3nz5unmm29WQUGB9u/fr0ceeURlZWWqq6vT8OHDrZeXcL29vVq5cqVmzJihyZMnSzr5ekhNTdXo0aNj9h3Kr4e+zoMk3X777Ro/frzy8vK0Z88ePfzww2poaNCbb75puNpYAz5A+IeysrLon6dOnaqioiKNHz9eb7zxhu666y7DlWEguPXWW6N/njJliqZOnaqJEyeqpqZGc+bMMVxZcpSXl2vv3r3nxfugZ9Lfebj77rujf54yZYpyc3M1Z84c7d+/XxMnTjzXy+zTgP8WXGZmpoYPH37ap1ja2tqUk5NjtKqBYfTo0br88svV2NhovRQzX74GeH2cbsKECcrMzBySr48VK1bo7bff1vvvvx/z61tycnJ0/Phxtbe3x+w/VF8P/Z2HvhQVFUnSgHo9DPgApaamatq0aaquro4+1tvbq+rqahUXFxuuzN6RI0e0f/9+5ebmWi/FTEFBgXJycmJeH5FIRDt27DjvXx+ffvqpDh8+PKReH845rVixQps2bdK2bdtUUFAQ8/y0adOUkpIS83poaGjQgQMHhtTr4WznoS+7d++WpIH1erD+FMTX8dprr7lgMOiqqqrcH//4R3f33Xe70aNHu9bWVuulnVMPPPCAq6mpcU1NTe53v/udKykpcZmZme7QoUPWS0uqjo4O9/HHH7uPP/7YSXJPP/20+/jjj93f/vY355xzP/3pT93o0aPdli1b3J49e9z8+fNdQUGB++KLL4xXnlhnOg8dHR3uwQcfdHV1da6pqcm999577rvf/a677LLL3LFjx6yXnjDLly93oVDI1dTUuJaWluh29OjR6D733HOPGzdunNu2bZvbuXOnKy4udsXFxYarTryznYfGxkb34x//2O3cudM1NTW5LVu2uAkTJrhZs2YZrzzWoAiQc84999xzbty4cS41NdVNnz7d1dfXWy/pnFu8eLHLzc11qamp7tvf/rZbvHixa2xstF5W0r3//vtO0mnbkiVLnHMnP4r92GOPuezsbBcMBt2cOXNcQ0OD7aKT4Ezn4ejRo27u3LluzJgxLiUlxY0fP94tW7ZsyP2ftL7++SW59evXR/f54osv3L333uu+9a1vuVGjRrmFCxe6lpYWu0UnwdnOw4EDB9ysWbNcRkaGCwaD7tJLL3U/+tGPXDgctl34Kfh1DAAAEwP+PSAAwNBEgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJj4f4W4/AnknuSPAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt \n",
    "\n",
    "from cProfile import label\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "def img_show(img):\n",
    "    # print(img)\n",
    "    pil_img = Image.fromarray(np.uint8(img))\n",
    "    # pil_img.show()\n",
    "    plt.imshow(pil_img)\n",
    "    plt.show()\n",
    "\n",
    "img = x_train[0]\n",
    "label = t_train[0]\n",
    "print(label)\n",
    "\n",
    "print(img.shape)\n",
    "img = img.reshape(28, 28) #读入的数据是一维的\n",
    "print(img.shape)\n",
    "\n",
    "img_show(img)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "867d3ee9",
   "metadata": {},
   "source": [
    "### 神经网络的推理处理\n",
    "\n",
    "神经网络的输入层有784个神经元，输出层有10个神经元\n",
    "输入层的784来自于28*28=784，输出0-9这十个数\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "120450cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "from 激活函数 import *\n",
    "from dataset.mnist import load_mnist\n",
    "def get_data():\n",
    "    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)\n",
    "    return x_test, t_test\n",
    "\n",
    "def init_network():\n",
    "    '''读取文件中保存的权重参数，文件中以字典的形式保存了偏重和偏置参数'''\n",
    "    with open(\"./deep_learning_demo/ch03/sample_weight.pkl\", 'rb') as f:\n",
    "        network = pickle.load(f)\n",
    "    \n",
    "    return network\n",
    "\n",
    "def predict(network, x):\n",
    "    '''神经网络识别'''\n",
    "    W1, W2, W3 = network['W1'], network['W2'], network['W3'] \n",
    "    b1, b2, b3 = network['b1'], network['b2'], network['b3']\n",
    "    a1 = np.dot(x, W1) + b1 \n",
    "    z1 = sigmoid(a1)\n",
    "    a2 = np.dot(z1, W2) + b2 \n",
    "    z2 = sigmoid(a2)\n",
    "    a3 = np.dot(z2, W3) + b3 \n",
    "    y = softmax(a3)\n",
    "    return y\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f6f80d3d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy:0.1032\n"
     ]
    }
   ],
   "source": [
    "x, t = get_data()\n",
    "network = init_network()\n",
    "\n",
    "accuracy_cnt = 0\n",
    "for i in range(len(x)):\n",
    "    y = predict(network, x[1])\n",
    "    p = np.argmax(y)\n",
    "    if p == t[i]:\n",
    "        accuracy_cnt += 1\n",
    "\n",
    "print(\"Accuracy:\" + str(float(accuracy_cnt) / len(x)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d65593b",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "699e709c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "1"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "vscode": {
   "interpreter": {
    "hash": "d99a3f7b344b3c3107482760db15f42178bfad658d282ab0a919b76809e13cb5"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
